\[ \DeclareMathOperator*{\argmin}{argmin\;} \]
Define the function \[ S_\lambda(y) = \argmin_\theta \frac{1}{2}(y - \theta)^2 + \lambda\lvert \theta \rvert. \] The intuition here is that \(\theta\) is a parameter of a statistical model, and \(y\) is data. The first (quadratic) term rewards good fit to the data, while the second term rewards \(\theta\) for being “simpler” (i.e. closer to zero). \(S_\lambda(y)\) returns an estimate for \(\theta\) that blends these two goals.
First show (in a trivial one- or two-liner) that the quadratic term in the objective above is the negative log-likelihood of a Gaussian distribution with mean \(\theta\) and variance 1. Then prove that \[ S_\lambda(y) = \mathrm{sign}(y)(\lvert y \rvert - \lambda)_+. \] Plot this as a function of \(y\) for a few different values of \(\lambda\). You’ll see how it encourages sparsity in a “soft” way, especially if you compare it to the hard-thresholding function \[ H_\lambda(y) = \begin{cases} y & \text{if} \quad \lvert y \rvert > \lambda \\ 0 & \text{otherwise.} \end{cases} \]
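Before diving into the proof, it helps to have a purely numerical version of the definition to check the final answer against. Here is a minimal sketch (the name S_num is my own); it simply hands the objective to base R’s optimize:

# Numerical version of S_lambda: minimize the objective by brute force
S_num <- function(y, lambda) {
  obj <- function(theta) 0.5 * (y - theta)^2 + lambda * abs(theta)
  # the minimizer always lies between 0 and y, so this interval is safe
  optimize(obj, interval = c(-abs(y) - 1, abs(y) + 1))$minimum
}
S_num(1.5, 0.4)  # approximately 1.1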
If \[ y\sim N(\theta, 1) \] then \[ \begin{aligned} l(\theta) & = - \log p(y) \\ & = - \log\left((2\pi)^{-1/2} \exp\left\{-\frac{1}{2}(y - \theta)^2\right\}\right) \\ & = \frac{1}{2}(y - \theta)^2 + \textit{const.} \end{aligned} \] This proves the first part.
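A quick numerical confirmation of this (nothing beyond dnorm is needed): the negative Gaussian log-density and the quadratic term differ only by the constant \(\frac{1}{2}\log(2\pi)\).

# The two quantities below agree (up to floating point)
y0 <- 1.3; th0 <- 0.2
-dnorm(y0, mean = th0, sd = 1, log = TRUE)  # negative log-likelihood
0.5 * (y0 - th0)^2 + 0.5 * log(2 * pi)      # quadratic term plus the constant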
We now want to find \(S_\lambda(y)\). I will use the following fact from calculus:

| Calculus Fact. Let \(f\) be a function that is piecewise smooth on the real line, continuous everywhere, and attains a minimum at \(x^*\). Then either \(x^*\) is a critical point in a region where \(f\) is differentiable (meaning \(f'(x^*) = 0\)), or \(x^*\) is one of the isolated non-differentiable points. For instance, \(f(x) = \lvert x \rvert\) attains its minimum at the kink \(x = 0\), where it is not differentiable. |
Back to our problem: the function \[ g(\theta) = \frac{1}{2}(y - \theta)^2 + \lambda\lvert \theta \rvert \] is continuous everywhere and differentiable at every point except \(\theta = 0\). It definitely attains a minimum, since it is continuous and grows without bound as \(\lvert \theta \rvert \to \infty\). So \(S_\lambda(y)\) will either be a critical point in \(\mathbb{R} \setminus \{0\}\), or \(S_\lambda(y) = 0\). The actual solution will depend on the value of \(y\).
We now compute the possible critical points of \(g\) on each of the two smooth pieces, along with the value at the kink \(\theta = 0\).
Case \(\theta = 0\). The kink \(\theta = 0\) is always a candidate for the minimum, and there \[ g(0) = \frac{1}{2}y^2. \]
Case \((-\infty , 0)\). On the interval \((-\infty, 0)\) we have \(\lvert \theta \rvert = -\theta\), so \[ g(\theta) = \frac{1}{2}(y - \theta)^2 - \lambda \theta, \] and a critical point would need to satisfy \(g'(\theta) = 0\), or equivalently \[ \theta^* = y + \lambda, \] so there is a critical point in \((-\infty, 0)\) only if \[ y < - \lambda; \] otherwise there is no critical point on this interval.
Case \((0, \infty)\). Identical reasoning shows that on \((0, \infty)\), where \(g(\theta) = \frac{1}{2}(y - \theta)^2 + \lambda \theta\), there is a critical point only when \[ y > \lambda, \] in which case \[ \theta^* = y - \lambda. \]

Conclusion. Putting together the above results, we conclude that there is a critical point if and only if \(\lvert y \rvert > \lambda\). If this condition does not hold, then the only possibility is \(\theta^* = 0\). Suppose now that this condition indeed holds; then there is exactly one critical point, and we have \[ g(\theta^*) = \begin{cases} \frac{1}{2}\lambda^2 - \lambda(y + \lambda) & \text{ if } y < - \lambda\\ \frac{1}{2}\lambda^2 + \lambda(y - \lambda) & \text{ if } y > \lambda \\ \frac{1}{2} y^2 & \text{otherwise.} \end{cases} \] In both of the first two cases, a little algebra gives \[ g(\theta^*) - g(0) = -\frac{1}{2}(\lvert y \rvert - \lambda)^2 < 0, \] so the critical point beats the kink whenever it exists. Altogether, we conclude \[ S_\lambda(y) = \begin{cases} y - \mathrm{sign}(y)\lambda & \text{ if } \lvert y \rvert > \lambda\\ 0 & \text{otherwise,} \end{cases} \] or, even neater, \[ S_\lambda(y) = \mathrm{sign}(y)(\lvert y \rvert - \lambda)_+. \]
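As a sanity check on the derivation, we can compare this closed form against the brute-force minimizer S_num defined earlier; the gap should be on the order of optimize’s default tolerance.

# Closed form vs. numerical argmin over a grid of y values
soft <- function(y, lambda) sign(y) * pmax(abs(y) - lambda, 0)
ys <- seq(-2, 2, length.out = 41)
max(abs(sapply(ys, S_num, lambda = 0.4) - soft(ys, 0.4)))  # tiny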
Plotting. We now do a quick numerical investigation.
library(plotly)
lambda <- 0.4
y <- seq(-1, 1, length.out = 10)
th <- seq(-1, 1, length.out = 1000)
g <- function(y, th) 0.5 * (y - th)^2 + lambda * abs(th)
obj_vals <- outer(y, th, g)
S_lambda <- sign(y) * pmax(abs(y) - lambda, 0)
S_lambda_value <- g(y, S_lambda)
plot_ly() %>%
add_surface(x = ~th, y = ~y, z = ~obj_vals) %>%
add_trace(x = ~S_lambda, y = ~y, z = ~S_lambda_value, type = "scatter3d", mode = "lines", line = list(width = 10))
lambda <- 0.3
y <- seq(-1, 1, length.out = 1000)
soft <- sign(y) * pmax(abs(y) - lambda, 0)
hard <- y * (abs(y) > lambda)
plot(y, soft, type = "l", col = "red", ylim = c(-1, 1))
lines(y, hard, col = "blue")
legend("bottomright", legend = c("soft", "hard"), col = c("red", "blue"), lty = 1)
title("Soft vs Hard Thresholding with lambda = 0.3")
Here’s a simple toy example to illustrate how soft thresholding can be used to enforce sparsity in statistical models. Suppose we observe data from the following statistical model: \[ y_i \mid \theta_i \sim N(\theta_i, \sigma_i^2) \] Now suppose we believe that a lot of the \(\theta_i\)’s are zero. Consider an estimator for each \(\theta_i\) of the form \[ \hat{\theta}(y_i) = S_{\lambda\sigma_i^2}(y_i). \]
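In code, this estimator is a one-liner; here is a sketch (the name theta_hat is mine):

# Soft-threshold each observation, with the threshold scaled by its variance
theta_hat <- function(y, lambda, sigma) sign(y) * pmax(abs(y) - lambda * sigma^2, 0)

Since the simulation below uses \(\sigma_i = 1\) throughout, this reduces to plain \(S_\lambda\).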
We first construct the sparse vector of true parameters.
set.seed(110104)
n <- 100
# About half of the entries of theta will be zero
mask <- rbinom(n, 1, 0.5)
pretheta <- seq(-3, 3, length.out = n)
theta <- pretheta * mask
We now simulate one noisy observation per parameter, all with unit variance.
sigma <- 1
y <- rnorm(n, mean = theta, sd = sigma)
Now we apply soft thresholding to each \(y_i\), for a grid of \(\lambda\) values.
lambda <- seq(0, 3, length.out = 8)
grid <- outer(y, lambda, function(y, lambda) sign(y) * pmax(abs(y) - lambda, 0))
We plot the estimates against the true parameters for each \(\lambda\).
par(mfrow = c(2,4))
for (j in 1:8) {
plot(theta, grid[ ,j], pch = 21, bg = "blue", xlab = expression(theta), ylab = expression(hat(theta)), xlim = c(-3, 3), ylim = c(-3, 3))
abline(a = 0, b = 1, col = "red")
title(bquote(lambda == .(round(lambda[j], 2))))
}
We now want to see the MSE as a function of \(\lambda\) and try different levels of sparsity.
density <- seq(0.1, 0.8, .1)
plot_data <- data.frame()
for (i in 1:8) {
mask <- rbinom(n, 1, density[i])
theta <- pretheta * mask
y <- rnorm(n, mean = theta, sd = sigma)
mse <- sapply(lambda, function(lam) mean((theta - sign(y) * pmax(abs(y) - lam, 0))^2))
plot_data <- rbind(plot_data, data.frame(
lambda = lambda,
mse = mse,
density = density[i]
))
}
# Use ggplot2 for a cleaner summary plot
library(ggplot2)
ggplot(plot_data, aes(x = lambda, y = mse, group = factor(density), colour = density)) +
geom_line() +
ggtitle("MSE varying by density") + xlab(expression(lambda))